{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Minimal tutorial on packing and unpacking sequences in PyTorch\n",
    "# aka how to use `pack_padded_sequence` and  `pad_packed_sequence`\n",
    "\n",
    "This is a jupyter version of [@Tushar-N 's gist](https://gist.github.com/Tushar-N/dfca335e370a2bc3bc79876e6270099e) with comments from [@Harsh Trivedi repo](https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial\n",
    "import torch\n",
    "from torch import LongTensor\n",
    "from torch.nn import Embedding, LSTM\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "\n",
    "## We want to run LSTM on a batch of 3 character sequences ['long_str', 'tiny', 'medium']\n",
    "#\n",
    "#     Step 1: Construct Vocabulary\n",
    "#     Step 2: Load indexed data (list of instances, where each instance is list of character indices)\n",
    "#     Step 3: Make Model\n",
    "#  *  Step 4: Pad instances with 0s till max length sequence\n",
    "#  *  Step 5: Sort instances by sequence length in descending order\n",
    "#  *  Step 6: Embed the instances\n",
    "#  *  Step 7: Call pack_padded_sequence with embeded instances and sequence lengths\n",
    "#  *  Step 8: Forward with LSTM\n",
    "#  *  Step 9: Call unpack_padded_sequences if required / or just pick last hidden vector\n",
    "#  *  Summary of Shape Transformations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We want to run LSTM on a batch following 3 character sequences\n",
    "seqs = ['long_str',  # len = 8\n",
    "        'tiny',      # len = 4\n",
    "        'medium']    # len = 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 1: Construct Vocabulary ##\n",
    "##------------------------------##\n",
    "# make sure <pad> idx is 0\n",
    "vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 2: Load indexed data (list of instances, where each instance is list of character indices) ##\n",
    "##-------------------------------------------------------------------------------------------------##\n",
    "vectorized_seqs = [[vocab.index(tok) for tok in seq]for seq in seqs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorized_seqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 3: Make Model ##\n",
    "##--------------------##\n",
    "embed = Embedding(len(vocab), 4) # embedding_dim = 4\n",
    "lstm = LSTM(input_size=4, hidden_size=5, num_layers=2, batch_first=True) # input_dim = 4, hidden_dim = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 4: Pad instances with 0s till max length sequence ##\n",
    "##--------------------------------------------------------##\n",
    "\n",
    "# get the length of each seq in your batch\n",
    "seq_lengths = LongTensor(list(map(len, vectorized_seqs)))\n",
    "# seq_lengths => [ 8, 4,  6]\n",
    "# batch_sum_seq_len: 8 + 4 + 6 = 18\n",
    "# max_seq_len: 8\n",
    "\n",
    "seq_tensor = (torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long()\n",
    "# seq_tensor => [[0 0 0 0 0 0 0 0]\n",
    "#                [0 0 0 0 0 0 0 0]\n",
    "#                [0 0 0 0 0 0 0 0]]\n",
    "\n",
    "for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):\n",
    "    seq_tensor[idx, :seqlen] = LongTensor(seq)\n",
    "# seq_tensor => [[ 6  9  8  4  1 11 12 10]          # long_str\n",
    "#                [12  5  8 14  0  0  0  0]          # tiny\n",
    "#                [ 7  3  2  5 13  7  0  0]]         # medium\n",
    "# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_lengths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 5: Sort instances by sequence length in descending order ##\n",
    "##---------------------------------------------------------------##\n",
    "\n",
    "seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)\n",
    "seq_tensor = seq_tensor[perm_idx]\n",
    "# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perm_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 6: Embed the instances ##\n",
    "##-----------------------------##\n",
    "\n",
    "embedded_seq_tensor = embed(seq_tensor)\n",
    "# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedded_seq_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedded_seq_tensor.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 7: Call pack_padded_sequence with embeded instances and sequence lengths ##\n",
    "##-------------------------------------------------------------------------------##\n",
    "\n",
    "packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)\n",
    "# packed_input (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes\n",
    "# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)\n",
    "#\n",
    "# packed_input.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1]\n",
    "# visualization :\n",
    "# l  o  n  g  _  s  t  r   #(long_str)\n",
    "# m  e  d  i  u  m         #(medium)\n",
    "# t  i  n  y               #(tiny)\n",
    "# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "packed_input.data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 8: Forward with LSTM ##\n",
    "##---------------------------##\n",
    "\n",
    "packed_output, (ht, ct) = lstm(packed_input)\n",
    "# packed_output (PackedSequence is NamedTuple with 2 attributes: data and batch_sizes\n",
    "# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)\n",
    "\n",
    "# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)\n",
    "# visualization :\n",
    "# l  o  n  g  _  s  t  r   #(long_str)\n",
    "# m  e  d  i  u  m         #(medium)\n",
    "# t  i  n  y               #(tiny)\n",
    "# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "packed_output.data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ht"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ht.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ct.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Step 9: Call unpack_padded_sequences if required / or just pick last hidden vector ##\n",
    "##------------------------------------------------------------------------------------##\n",
    "\n",
    "# unpack your output if required\n",
    "output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)\n",
    "# output:\n",
    "# output.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Or if you just want the final hidden state?\n",
    "print(ht[-1])\n",
    "\n",
    "## Summary of Shape Transformations ##\n",
    "##----------------------------------##\n",
    "\n",
    "# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)\n",
    "# (batch_size X max_seq_len X embedding_dim) --->      Pack     ---> (batch_sum_seq_len X embedding_dim)\n",
    "# (batch_sum_seq_len X embedding_dim)        --->      LSTM     ---> (batch_sum_seq_len X hidden_dim)\n",
    "# (batch_sum_seq_len X hidden_dim)           --->    UnPack     ---> (batch_size X max_seq_len X hidden_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
